package ltd.nullpointer.tcp.core;

import com.google.common.collect.BiMap;
import com.google.common.collect.HashBiMap;
import com.google.common.collect.Maps;
import io.netty.channel.Channel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Component;

import java.util.Map;
import java.util.Observable;
import java.util.concurrent.ConcurrentHashMap;

//import com.boot2.npiot.api.constant.TCPEnum.ErrorCode;
//import com.boot2.npiot.api.tcp.exception.MessageResolverException;
//import com.boot2.npiot.dao.jpa.DeviceJpaDao;
//import com.boot2.npiot.kafka.constant.KafkaTopicConstant;
//import com.boot2.npiot.kafka.producer.KafkaProducer;
//import com.boot2.npiot.model.jpa.Device;

/**
 * Tracks live TCP sessions: which netty {@link Channel} belongs to which device, and which AES key
 * and product id are associated with each channel.
 * <p>
 * Thread-safety: every bimap is wrapped with {@link Maps#synchronizedBiMap}, so individual
 * get/put/remove calls are atomic. The check-then-act in {@link #add(String, Channel)} is guarded by
 * the session bimap's own lock. The {@code addOrUpdate*} methods and {@link #getAesKey(Channel)}
 * are {@code synchronized} on this instance.
 *
 * @ClassName: SessionManager
 * @Description: TCP session management
 * @author zhangweilin
 * @date 2017-11-09 12:45:54
 */
@Component
public class SessionManager extends Observable {
	private static final Logger logger = LoggerFactory.getLogger(SessionManager.class);

//	@Autowired
//	KafkaProducer<Object> kafkaProducer;
//
//	@Autowired
//	DeviceJpaDao deviceJpaDao;

	/**
	 * Session map: channel -> deviceId. Being a bimap, each channel and each deviceId appears at
	 * most once. Public for backward compatibility; callers that iterate it (or its inverse) must
	 * synchronize on this map, as required by {@link Maps#synchronizedBiMap}.
	 */
	public final BiMap<Channel, String> sessionMap;

	/**
	 * Channel -> AES key used for that channel's communication.
	 */
	private final BiMap<Channel, String> aesKeyMap;

	/**
	 * Channel -> productId.
	 */
	public final BiMap<Channel, String> productChannelMap;

	/**
	 * Per-channel count of {@link #getAesKey(Channel)} calls (it stops growing at 2). Internal use
	 * only; entries are dropped in {@link #remove(Channel)} so closed channels are not retained.
	 */
	private final Map<Channel, Integer> refAesKeyCountMap;

	public SessionManager() {
		sessionMap = Maps.synchronizedBiMap(HashBiMap.<Channel, String>create());
		aesKeyMap = Maps.synchronizedBiMap(HashBiMap.<Channel, String>create());
		productChannelMap = Maps.synchronizedBiMap(HashBiMap.<Channel, String>create());
		refAesKeyCountMap = new ConcurrentHashMap<>();
	}

	/**
	 * Registers {@code channel} as the session of {@code deviceId}.
	 * <p>
	 * A {@code null} channel is ignored. If {@code deviceId} already has a session, the call is a
	 * no-op: duplicate logins are currently tolerated (for testing) instead of being rejected.
	 *
	 * @param deviceId device identifier
	 * @param channel  channel the device is connected on; {@code null} is ignored
	 */
	public void add(String deviceId, Channel channel) {
//		Device device = deviceJpaDao.findByDeviceSn(deviceId);
//		if (null == device) {
//			throw new MessageResolverException(ErrorCode.err10014.getErrCode(), ErrorCode.err10014.getName() + ",上线失败，deviceId: " + deviceId);
//		}
		if (channel == null) {
			return;
		}
		final boolean added;
		// containsValue + put must be atomic: otherwise two concurrent logins of the same device
		// could both pass the check and the second BiMap.put would throw IllegalArgumentException.
		// The synchronized bimap uses itself as its mutex, so locking it here composes with its
		// own per-call locking.
		synchronized (sessionMap) {
			added = !sessionMap.containsValue(deviceId);
			if (added) {
				sessionMap.put(channel, deviceId);
			}
		}
		if (added) {
			logger.trace("Sessions after add: {}", sessionMap.size());
		} else {
			// TODO: duplicate logins are allowed while testing; production should reject them with:
//			throw new MessageResolverException(TCPEnum.ErrorCode.err10013.getErrCode(), TCPEnum.ErrorCode.err10013.getName() + ",deviceId: " + deviceId);
			logger.debug("deviceId {} already has a session; not re-registering it on channel {}", deviceId, channel);
		}
	}

	/**
	 * Registers the device session and the AES key of {@code channel}.
	 *
	 * @param deviceId device identifier
	 * @param aesKey   AES key used on {@code channel}
	 * @param channel  channel the device is connected on; {@code null} is ignored
	 * @throws IllegalArgumentException if {@code aesKey} is already bound to a different channel
	 *                                  (see {@link BiMap#put})
	 */
	public void add(String deviceId, String aesKey, Channel channel) {
		add(deviceId, channel);
		addByAESKey(aesKey, channel);
	}

	/**
	 * Associates {@code aesKey} with {@code channel}, replacing any key the channel had before.
	 *
	 * @param aesKey  AES key used on {@code channel}
	 * @param channel channel to associate; {@code null} is ignored
	 * @throws IllegalArgumentException if {@code aesKey} is already bound to a different channel
	 *                                  (see {@link BiMap#put})
	 */
	public void addByAESKey(String aesKey, Channel channel) {
		if (channel == null) {
			return;
		}
		aesKeyMap.put(channel, aesKey);
		logger.trace("aesKeyMap after add: {}", aesKeyMap.size());
	}

	/**
	 * Associates {@code productId} with {@code channel}, replacing any product id the channel had.
	 * <p>
	 * NOTE(review): because this is a bimap, two channels cannot share one productId; confirm that
	 * this is intended.
	 *
	 * @param productId product identifier
	 * @param channel   channel to associate; {@code null} is ignored
	 * @throws IllegalArgumentException if {@code productId} is already bound to a different channel
	 *                                  (see {@link BiMap#put})
	 */
	public void addByProductId(String productId, Channel channel) {
		if (channel == null) {
			return;
		}
		productChannelMap.put(channel, productId);
		logger.trace("productChannelMap after add: {}", productChannelMap.size());
	}

	/**
	 * @return the deviceId registered for {@code channel}, or {@code null} if none
	 */
	public String getDeviceId(Channel channel) {
		return sessionMap.get(channel);
	}

	/**
	 * @return the productId registered for {@code channel}, or {@code null} if none
	 */
	public String getProductId(Channel channel) {
		return productChannelMap.get(channel);
	}

	/**
	 * Returns the AES key of {@code channel}. The first two calls for a given channel always
	 * return {@code null} (they only advance a per-channel counter); later calls return the
	 * registered key, or {@code null} if none is registered.
	 * <p>
	 * NOTE(review): presumably the first two messages on a connection are not AES-encrypted
	 * (e.g. a handshake); confirm against the protocol.
	 *
	 * @param channel the channel; must not be {@code null}
	 * @return the AES key, or {@code null} as described above
	 */
	public synchronized String getAesKey(Channel channel) {
		Integer count = refAesKeyCountMap.get(channel);
		count = count == null ? 0 : count;
		if (count < 2) {
			count++;
			refAesKeyCountMap.put(channel, count);
			return null;
		}
		return aesKeyMap.get(channel);
	}

	/**
	 * Returns the online devices as deviceId -> channel.
	 * <p>
	 * This is a live, mutable view (the inverse of {@link #sessionMap}), not a snapshot. Iterating
	 * it requires synchronizing on {@link #sessionMap}.
	 *
	 * @return live deviceId -> channel view
	 */
	public Map<String, Channel> getOnlineMap() {
		BiMap<String, Channel> mapper = sessionMap.inverse();
		return mapper;
	}

	/**
	 * @return the channel of {@code deviceId}, or {@code null} if the device is not online
	 */
	public Channel getChannelByDeviceId(String deviceId) {
		return getOnlineMap().get(deviceId);
	}

	/**
	 * @return the channel that uses {@code aesKey}, or {@code null} if none
	 */
	public Channel getChannelByAESKey(String aesKey) {
		BiMap<String, Channel> mapper = aesKeyMap.inverse();
		return mapper.get(aesKey);
	}

	/**
	 * Removes every association of {@code channel}: its device session, AES key, product id and
	 * AES-key call counter.
	 *
	 * @param channel channel to remove; {@code null} is tolerated
	 * @return {@code channel}, unchanged
	 */
	public Channel remove(Channel channel) {
		final String deviceId = sessionMap.remove(channel);
		aesKeyMap.remove(channel);
		productChannelMap.remove(channel);
		// Without this, every channel ever passed to getAesKey() would stay referenced forever.
		// ConcurrentHashMap rejects null keys, hence the guard.
		if (channel != null) {
			refAesKeyCountMap.remove(channel);
		}
//		kafkaProducer.send(KafkaTopicConstant.tcp_channel_remove, deviceId);
		logger.trace("Remove channel to deviceId = {}, sessions after remove: {}", deviceId, sessionMap.size());
		return channel;
	}

	/**
	 * Called when a message is received: adds the device session, or moves it to
	 * {@code channel} if the device was registered on a different channel. In that case the old
	 * channel is fully removed (see {@link #remove(Channel)}).
	 *
	 * @param deviceId device identifier
	 * @param channel  channel the message arrived on; {@code null} is ignored
	 */
	public synchronized void addOrUpdateChannel(String deviceId, Channel channel) {
		if (channel == null) {
			// add() ignores null channels anyway; bail out before channel.equals() would NPE.
			return;
		}
		Channel sessionChannel = getChannelByDeviceId(deviceId);
		if (sessionChannel == null) {
			add(deviceId, channel);
		} else if (!channel.equals(sessionChannel)) {
			remove(sessionChannel);
			add(deviceId, channel);
		}
	}

	/**
	 * Called when a message is received: associates {@code aesKey} with {@code channel}. If the
	 * key was held by a different channel, that old channel is fully removed first (see
	 * {@link #remove(Channel)}).
	 *
	 * @param aesKey  AES key used on {@code channel}
	 * @param channel channel the message arrived on; {@code null} is ignored
	 */
	public synchronized void addOrUpdateChannelByAESKey(String aesKey, Channel channel) {
		if (channel == null) {
			// addByAESKey() ignores null channels anyway; bail out before channel.equals() would NPE.
			return;
		}
		Channel sessionChannel = getChannelByAESKey(aesKey);
		if (sessionChannel == null) {
			addByAESKey(aesKey, channel);
		} else if (!channel.equals(sessionChannel)) {
			remove(sessionChannel);
			addByAESKey(aesKey, channel);
		}
	}

	/**
	 * Called when a message is received: adds or updates both the device session and the AES key
	 * association of {@code channel}.
	 *
	 * @param deviceId device identifier
	 * @param aesKey   AES key used on {@code channel}
	 * @param channel  channel the message arrived on; {@code null} is ignored
	 */
	public synchronized void addOrUpdateChannel(String deviceId, String aesKey, Channel channel) {
		addOrUpdateChannel(deviceId, channel);
		addOrUpdateChannelByAESKey(aesKey, channel);
	}

	/**
	 * @return {@code true} if {@code deviceId} currently has a registered channel
	 */
	public boolean isOnline(String deviceId) {
		return getChannelByDeviceId(deviceId) != null;
	}

	/**
	 * @return number of channels with a registered AES key
	 */
	public int getAesChannelSize() {
		return aesKeyMap.size();
	}

}
